{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating the Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import datasets\n",
    "import torchvision.transforms as transforms\n",
    "import os\n",
    "\n",
    "# path to store/load data\n",
    "path2data=\"./data\"\n",
    "os.makedirs(path2data, exist_ok= True)\n",
    "    \n",
    "\n",
    "h, w = 64, 64\n",
    "mean = (0.5, 0.5, 0.5)\n",
    "std = (0.5, 0.5, 0.5)\n",
    "transform= transforms.Compose([\n",
    "           transforms.Resize((h,w)),\n",
    "           transforms.CenterCrop((h,w)),\n",
    "           transforms.ToTensor(),\n",
    "           transforms.Normalize(mean, std)])\n",
    "    \n",
    "train_ds=datasets.STL10(path2data, split='train', \n",
    "                        download=True,\n",
    "                        transform=transform)\n",
    "print(len(train_ds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "for x, _ in train_ds:\n",
    "    print(x.shape, torch.min(x), torch.max(x))\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "for x,y in train_ds:\n",
    "    print(x.shape,y)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.transforms.functional import to_pil_image\n",
    "import matplotlib.pylab as plt\n",
    "%matplotlib inline\n",
    "\n",
    "plt.imshow(to_pil_image(0.5*x+0.5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "batch_size = 32\n",
    "train_dl = torch.utils.data.DataLoader(train_ds, \n",
    "                                       batch_size=batch_size, \n",
    "                                       shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for x,y in train_dl:\n",
    "    print(x.shape, y.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class Generator(nn.Module):\n",
    "    def __init__(self, params):\n",
    "        super(Generator, self).__init__()\n",
    "        nz = params[\"nz\"]\n",
    "        ngf = params[\"ngf\"]\n",
    "        noc = params[\"noc\"]\n",
    "        self.dconv1 = nn.ConvTranspose2d( nz, ngf * 8, kernel_size=4,\n",
    "                                         stride=1, padding=0, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(ngf * 8)\n",
    "        self.dconv2 = nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size=4, \n",
    "                                         stride=2, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(ngf * 4)\n",
    "        self.dconv3 = nn.ConvTranspose2d( ngf * 4, ngf * 2, kernel_size=4, \n",
    "                                         stride=2, padding=1, bias=False)\n",
    "        self.bn3 = nn.BatchNorm2d(ngf * 2)\n",
    "        self.dconv4 = nn.ConvTranspose2d( ngf * 2, ngf, kernel_size=4, \n",
    "                                         stride=2, padding=1, bias=False)\n",
    "        self.bn4 = nn.BatchNorm2d(ngf)\n",
    "        self.dconv5 = nn.ConvTranspose2d( ngf, noc, kernel_size=4, \n",
    "                                         stride=2, padding=1, bias=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.bn1(self.dconv1(x)))\n",
    "        x = F.relu(self.bn2(self.dconv2(x)))            \n",
    "        x = F.relu(self.bn3(self.dconv3(x)))        \n",
    "        x = F.relu(self.bn4(self.dconv4(x)))    \n",
    "        out = torch.tanh(self.dconv5(x))\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params_gen = {\n",
    "        \"nz\": 100,\n",
    "        \"ngf\": 64,\n",
    "        \"noc\": 3,\n",
    "        }\n",
    "model_gen = Generator(params_gen)\n",
    "device = torch.device(\"cuda:3\")\n",
    "model_gen.to(device)\n",
    "print(model_gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    y= model_gen(torch.zeros(1,100,1,1, device=device))\n",
    "print(y.shape)    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining Discriminator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Discriminator(nn.Module):\n",
    "    def __init__(self, params):\n",
    "        super(Discriminator, self).__init__()\n",
    "        nic= params[\"nic\"]\n",
    "        ndf = params[\"ndf\"]\n",
    "        self.conv1 = nn.Conv2d(nic, ndf, kernel_size=4, stride=2, padding=1, bias=False)\n",
    "        self.conv2 = nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(ndf * 2)            \n",
    "        self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1, bias=False)\n",
    "        self.bn3 = nn.BatchNorm2d(ndf * 4)\n",
    "        self.conv4 = nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1, bias=False)\n",
    "        self.bn4 = nn.BatchNorm2d(ndf * 8)\n",
    "        self.conv5 = nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=0, bias=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.leaky_relu(self.conv1(x), 0.2, True)\n",
    "        x = F.leaky_relu(self.bn2(self.conv2(x)), 0.2, inplace = True)\n",
    "        x = F.leaky_relu(self.bn3(self.conv3(x)), 0.2, inplace = True)\n",
    "        x = F.leaky_relu(self.bn4(self.conv4(x)), 0.2, inplace = True)        \n",
    "        \n",
    "        out = torch.sigmoid(self.conv5(x))\n",
    "        return out.view(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params_dis = {\n",
    "    \"nic\": 3,\n",
    "    \"ndf\": 64}\n",
    "model_dis = Discriminator(params_dis)\n",
    "model_dis.to(device)\n",
    "print(model_dis)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    y= model_dis(torch.zeros(1,3,h,w, device=device))\n",
    "print(y.shape)    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def initialize_weights(model):\n",
    "    classname = model.__class__.__name__\n",
    "    if classname.find('Conv') != -1:\n",
    "        nn.init.normal_(model.weight.data, 0.0, 0.02)\n",
    "    elif classname.find('BatchNorm') != -1:\n",
    "        nn.init.normal_(model.weight.data, 1.0, 0.02)\n",
    "        nn.init.constant_(model.bias.data, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_gen.apply(initialize_weights);\n",
    "model_dis.apply(initialize_weights);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining Loss, Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_func = nn.BCELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import optim\n",
    "\n",
    "lr = 2e-4 \n",
    "beta1 = 0.5\n",
    "opt_dis = optim.Adam(model_dis.parameters(), lr=lr, betas=(beta1, 0.999))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt_gen = optim.Adam(model_gen.parameters(), lr=lr, betas=(beta1, 0.999))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "real_label = 1\n",
    "fake_label = 0\n",
    "nz = params_gen[\"nz\"]\n",
    "num_epochs = 100\n",
    "\n",
    "\n",
    "loss_history={\"gen\": [],\n",
    "              \"dis\": []}\n",
    "\n",
    "batch_count = 0\n",
    "for epoch in range(num_epochs):\n",
    "    for xb, yb in train_dl:\n",
    "        ba_si = xb.size(0)\n",
    "        model_dis.zero_grad()\n",
    "        xb = xb.to(device)\n",
    "        yb = torch.full((ba_si,), real_label, device=device)\n",
    "        out_dis = model_dis(xb)\n",
    "        loss_r = loss_func(out_dis, yb)\n",
    "        loss_r.backward()\n",
    "\n",
    "        noise = torch.randn(ba_si, nz, 1, 1, device=device)\n",
    "        out_gen = model_gen(noise)\n",
    "        out_dis = model_dis(out_gen.detach())\n",
    "        yb.fill_(fake_label)    \n",
    "        loss_f = loss_func(out_dis, yb)\n",
    "        loss_f.backward()\n",
    "        loss_dis = loss_r + loss_f  \n",
    "        opt_dis.step()   \n",
    "\n",
    "        model_gen.zero_grad()\n",
    "        yb.fill_(real_label)  \n",
    "        out_dis = model_dis(out_gen)\n",
    "        loss_gen = loss_func(out_dis, yb)\n",
    "        loss_gen.backward()\n",
    "        opt_gen.step()\n",
    "\n",
    "        loss_history[\"gen\"].append(loss_gen.item())\n",
    "        loss_history[\"dis\"].append(loss_dis.item())\n",
    "        batch_count += 1\n",
    "        if batch_count % 100 == 0:\n",
    "            print(epoch, loss_gen.item(),loss_dis.item())\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,5))\n",
    "plt.title(\"Loss Progress\")\n",
    "plt.plot(loss_history[\"gen\"],label=\"Gen. Loss\")\n",
    "plt.plot(loss_history[\"dis\"],label=\"Dis. Loss\")\n",
    "plt.xlabel(\"batch count\")\n",
    "plt.ylabel(\"Loss\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# store models\n",
    "import os\n",
    "path2models = \"./models/\"\n",
    "os.makedirs(path2models, exist_ok=True)\n",
    "path2weights_gen = os.path.join(path2models, \"weights_gen_128.pt\")\n",
    "path2weights_dis = os.path.join(path2models, \"weights_dis_128.pt\")\n",
    "\n",
    "torch.save(model_gen.state_dict(), path2weights_gen)\n",
    "torch.save(model_dis.state_dict(), path2weights_dis)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Deploying Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "weights = torch.load(path2weights_gen)\n",
    "model_gen.load_state_dict(weights)\n",
    "model_gen.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "with torch.no_grad():\n",
    "    fixed_noise = torch.randn(16, nz, 1, 1, device=device)\n",
    "    print(fixed_noise.shape)\n",
    "    img_fake = model_gen(fixed_noise).detach().cpu()    \n",
    "print(img_fake.shape)\n",
    "\n",
    "plt.figure(figsize=(10,10))\n",
    "for ii in range(16):\n",
    "    plt.subplot(4,4,ii+1)\n",
    "    plt.imshow(to_pil_image(0.5*img_fake[ii]+0.5))\n",
    "    plt.axis(\"off\")\n",
    "    "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
